import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from openpyxl import Workbook
from typing import Tuple

# Scale factor for the diffusion schedule: it multiplies beta_t when alpha_t is
# built, and 1/k is the reference variance the noised images are compared to.
k = 20

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Convert images to tensors, then normalize each of the three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, read from a local directory (never downloaded).
trainset = torchvision.datasets.CIFAR10(
    root='/idas/users/liusirui/caculate_steps/data/',
    train=True,
    transform=transform,
    download=False,
)

# Shuffled mini-batches of 64 images, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

class GaussianDiffusionTrainer:
    """Forward (noising) process of a Gaussian diffusion with a k-scaled schedule.

    The schedule is built from a linearly spaced ``beta_t`` and the module-level
    scale ``k``: ``alpha_t = 1 - k * beta_t``. The cumulative product of
    ``alpha_t`` yields the per-timestep coefficients

    * ``signal_rate[t] = sqrt(alpha_bar_t)``
    * ``noise_rate[t]  = sqrt(1 - alpha_bar_t)``

    which ``forward`` uses to produce a noised sample from a clean one.
    All schedule tensors live on the module-level ``device``.
    """

    def __init__(self, beta: Tuple[float, float], T: int):
        """Precompute the signal and noise coefficients for every timestep.

        Args:
            beta: ``(start, end)`` endpoints of the linearly spaced beta schedule.
            T: number of timesteps, i.e. the length of the schedule.
        """
        self.T = T
        self.beta_t = torch.linspace(*beta, T, dtype=torch.float32).to(device)
        alpha_t = 1.0 - k * self.beta_t
        # beta_t is already on ``device``, so the cumulative product is too.
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        self.signal_rate = torch.sqrt(alpha_t_bar)
        self.noise_rate = torch.sqrt(1.0 - alpha_t_bar)

    def forward(self, x_0, t):
        """Return a noised version of ``x_0`` at timestep ``t``.

        Computes ``signal_rate[t] * x_0 + noise_rate[t] * sqrt(1 / k) * epsilon``
        with ``epsilon`` drawn from a standard normal of the same shape as
        ``x_0``. The same timestep is applied to every element of the batch.

        Args:
            x_0: batch of clean samples; the first dimension is the batch.
                Expected to live on the module-level ``device``, where the
                schedule tensors are stored.
            t: timestep index into the schedule, ``0 <= t < T``.

        Returns:
            Tensor with the same shape as ``x_0`` holding the noised samples.

        Raises:
            ValueError: if ``t`` is outside ``[0, T)``.
        """
        # Fail early with a clear message instead of an out-of-bounds gather
        # (which surfaces as a device-side assert on CUDA).
        if not 0 <= t < self.T:
            raise ValueError(f"t must satisfy 0 <= t < {self.T}, got {t}")

        t = torch.full((x_0.shape[0],), t, dtype=torch.int64, device=x_0.device)
        # Draw the noise on the same device as x_0 (randn_like's default) so it
        # matches the timestep tensor and the scaling factor below.
        epsilon = torch.randn_like(x_0)
        scaling_factor = torch.sqrt(torch.tensor(1 / k, device=x_0.device))
        # Sample x_t directly from x_0: scaled signal plus scaled Gaussian noise.
        x_t = (extract(self.signal_rate, t, x_0.shape) * x_0 +
               extract(self.noise_rate, t, x_0.shape) * epsilon * scaling_factor)
        return x_t

def extract(v, i, shape):
    """Pick the entries of ``v`` at indices ``i`` and shape them for broadcasting.

    Args:
        v: tensor of coefficients, gathered along dim 0.
        i: index tensor with one entry per batch element.
        shape: shape of the tensor the result will be combined with; only its
            number of dimensions is used.

    Returns:
        float32 tensor on the module-level ``device`` with shape
        ``(i.shape[0], 1, ..., 1)`` and ``len(shape)`` dimensions in total.
    """
    picked = torch.gather(v, dim=0, index=i)
    picked = picked.to(device=device, dtype=torch.float32)
    # Keep the batch dimension and add singleton dimensions for the rest.
    broadcast_shape = [i.shape[0]] + [1] * (len(shape) - 1)
    return picked.view(broadcast_shape)

def compute_pixel_variance(all_images):
    """Compute the per-pixel mean and sample variance across a stack of images.

    Statistics are taken over the first dimension, so every remaining position
    (channel, row, column) gets its own mean and variance. The variance uses
    the ``n - 1`` denominator (Bessel's correction), so at least two images
    are needed for a finite result.

    Args:
        all_images: tensor whose first dimension indexes the images,
            e.g. ``(num_images, channels, height, width)``.

    Returns:
        Tuple ``(mean, variance)``; each has the shape of ``all_images``
        without its first dimension and lives on the same device as the input.
    """
    # var_mean reduces in one call without materialising full-size difference
    # tensors, works on non-contiguous input, and keeps the result on the
    # input's device rather than depending on a module-level device.
    variance, mean_images = torch.var_mean(all_images, dim=0)
    return mean_images, variance

diffusion_trainer = GaussianDiffusionTrainer(beta=(0.0001, 0.02), T=1000)

results = []
t = 500

# Noise every training batch at timestep t. Batches are collected in a list
# and concatenated once afterwards; concatenating inside the loop would copy
# the whole accumulated tensor again on every iteration.
x_t_batches = []
for images, labels in trainloader:
    images = images.to(device)
    x_t_batches.append(diffusion_trainer.forward(images, t))
all_x_t_images = torch.cat(x_t_batches, dim=0)
del x_t_batches  # release the per-batch references before the statistics pass

# Per-pixel statistics of the noised images, summarised as L2 distances.
mean, variance = compute_pixel_variance(all_x_t_images)
mean_distance_to_zero = torch.norm(mean, p=2)
# NOTE: despite the column name written below, the reference variance is 1/k.
variance_distance_to_one = torch.norm(variance - (1/k)*torch.ones_like(variance), p=2)
results.append({
    't': t,
    'Mean distance to zero': mean_distance_to_zero.item(),
    'Variance distance to one': variance_distance_to_one.item()
})

df = pd.DataFrame(results)

# Write the table to an Excel sheet: header row first, then one row per result.
wb = Workbook()
ws = wb.active

ws.append(list(df.columns))

for _, row in df.iterrows():
    ws.append(list(row))

wb.save('results.xlsx')